Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type annotate pyro.primitives & poutine.block_messenger #3292

Merged
merged 5 commits into from
Nov 13, 2023
Merged

Conversation

ordabayevy
Copy link
Member

No description provided.

@ordabayevy ordabayevy added the WIP label Nov 12, 2023
@ordabayevy ordabayevy changed the title Type annotate primitives & poutine.block_messenger Type annotate pyro.primitives & poutine.block_messenger Nov 12, 2023

import torch
from torch.distributions import constraints
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see lots of new dependencies here. Historically we've had a number of cyclic dependency issues in Pyro. One thing we might consider to try to avoid cyclid dependencies is to guard these with an if TYPE_CHECKING:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from torch.distributions import contstraints
    from pyro.distributions import TorchDistribution
    from pyro.params.param_store import ParamStoreDict
    from pyro.poutine.runtime import Message

Or maybe we can just use that trick next time we need to fix a cyclic dependency. Either way.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great idea. However, in this case it turns out that make docs fails without TorchDistribution and ParamStoreDict imported. And constraints is used in actual code.

@@ -51,7 +51,7 @@ def __call__(self, sample_shape=torch.Size()):
)

@property
def event_dim(self):
def event_dim(self) -> int:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this file I type annotated only methods needed for pyro.primitives


def effectful(
fn: Optional[Callable[P, T]] = None, type: Optional[str] = None
) -> Callable:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the best I could do at being specific with callable types.

def sample(name, fn, *args, **kwargs):
def sample(
name: str,
fn: TorchDistribution,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this correct or does fn can be any callable returning torch.Tensor?

*args,
obs: Optional[torch.Tensor] = None,
obs_mask: Optional[torch.Tensor] = None,
infer: Optional[Dict[str, Union[str, bool]]] = None,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made these args explicit.

@@ -374,7 +406,9 @@ def __init__(self, *args, **kwargs):


@contextmanager
def plate_stack(prefix, sizes, rightmost_dim=-1):
def plate_stack(
prefix: str, sizes: Sequence[int], rightmost_dim: int = -1
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring says sizes is iterable, however, iterable is not reversible.

@@ -462,7 +498,7 @@ def module(name, nn_module, update_module_params=False):
param_name
] = target_state_dict[_name]
else:
nn_module._parameters[mod_name] = target_state_dict[_name]
nn_module._parameters[mod_name] = target_state_dict[_name] # type: ignore[assignment]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nn_module._parameters's type is nn.Parameter and target_state_dict[_name] is torch.Tensor.

Copy link
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great!

@fritzo fritzo merged commit e274bca into dev Nov 13, 2023
9 checks passed
@ordabayevy ordabayevy deleted the type-primitives branch November 13, 2023 03:14
@ordabayevy ordabayevy mentioned this pull request Feb 5, 2024
23 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants